/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.jmh.util;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Random;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.infra.Blackhole;

/** Benchmarks for {@link org.apache.beam.sdk.util.VarInt} and variants. */
@OperationsPerInvocation(VarIntBenchmark.VALUES_PER_INVOCATION)
public class VarIntBenchmark {
  // Number of input values held by each input state and encoded by one benchmark invocation.
  // The same constant is passed to @OperationsPerInvocation on this class.
  static final int VALUES_PER_INVOCATION = 2048;
  // Random source with a fixed seed, shared by the input state classes when generating values.
  private static final Random RNG = new Random(314159);

  /**
   * Benchmark-scoped output state whose {@link #stream} passes everything written to it to a
   * {@link Blackhole} and keeps nothing.
   */
  @State(Scope.Benchmark)
  public static class BlackholeOutput {
    OutputStream stream;

    /**
     * Creates the stream around the {@link Blackhole} supplied by JMH.
     *
     * @param bh the blackhole that receives every written value
     */
    @Setup
    public void setup(Blackhole bh) {
      stream = new BlackholeStream(bh);
    }

    /** {@link OutputStream} that forwards each write to a {@link Blackhole}. */
    private static final class BlackholeStream extends OutputStream {
      private final Blackhole sink;

      BlackholeStream(Blackhole sink) {
        this.sink = sink;
      }

      @Override
      public void write(int value) {
        sink.consume(value);
      }

      @Override
      public void write(byte[] buffer) {
        sink.consume(buffer);
      }

      @Override
      public void write(byte[] buffer, int offset, int length) {
        // Only the array reference is consumed; offset and length are not inspected.
        sink.consume(buffer);
      }
    }
  }

  /** Thread-scoped output state that collects encoded bytes in a {@link ByteStringOutputStream}. */
  @State(Scope.Thread)
  public static class ByteStringOutput {
    final ByteStringOutputStream stream = new ByteStringOutputStream();

    // The stream has to be emptied after each use, otherwise it keeps growing until the benchmark
    // runs out of memory. Level.Invocation is generally discouraged, but nothing else fits here.
    @TearDown(Level.Invocation)
    public void tearDown(Blackhole bh) {
      // Drain the stream and hand the drained contents to the blackhole.
      Object drained = stream.toByteStringAndReset();
      bh.consume(drained);
    }
  }

  /**
   * Input of {@code VALUES_PER_INVOCATION} values derived from random bytes.
   *
   * <p>Each value is a random byte masked with {@code 0x7F}, so it lies in {@code [0, 127]}.
   */
  @State(Scope.Benchmark)
  public static class Bytes {
    // Allocated once here; setup() only fills it, so no second array is created.
    final long[] values = new long[VALUES_PER_INVOCATION];

    /** Fills {@link #values} from one batch of random bytes drawn from the shared generator. */
    @Setup
    public void setup() {
      byte[] bytes = new byte[VALUES_PER_INVOCATION];
      RNG.nextBytes(bytes);

      for (int i = 0; i < VALUES_PER_INVOCATION; i++) {
        // Clearing the top bit keeps the value non-negative and below 128.
        values[i] = bytes[i] & 0x7F;
      }
    }
  }

  /**
   * Input of {@code VALUES_PER_INVOCATION} random longs of varying magnitude.
   *
   * <p>For each value a 7-bit group index is derived from a reshaped gaussian sample. The value's
   * highest set bit is then forced into that group, so the index determines how many 7-bit groups
   * the value occupies.
   */
  @State(Scope.Benchmark)
  public static class Longs {
    // A long holds groups 0..9: nine full 7-bit groups plus the group containing bit 63.
    private static final int MAX_GROUP_INDEX = 9;

    // Allocated once here; setup() only fills it, so no second array is created.
    final long[] values = new long[VALUES_PER_INVOCATION];

    /** Fills {@link #values} with random longs drawn from the shared generator. */
    @Setup
    public void setup() {
      for (int i = 0; i < VALUES_PER_INVOCATION; i++) {
        // This gaussian random is used to determine the encoded output size of the sample.
        // The distribution value is tweaked to favor small integers, positive more so than
        // negative.
        double g = RNG.nextGaussian();
        double s = 3;
        // Reshape the sample into the range [0, 10].
        g = 10 * Math.min(Math.abs(g < 0 ? g + s : g / (s / 2)), s) / s;

        // The reshaped sample can be exactly 10, which would ask for a shift by 70. Java only uses
        // the low six bits of a long shift distance, so that would act as a shift by 6 and yield a
        // small value. Clamp to the last group that exists in a long.
        int group = Math.min((int) g, MAX_GROUP_INDEX);

        // Construct a bitmask to keep the bits up to and including the chosen 7-bit group.
        // Find the lowest bit of that group so it can be forced on.
        int numBits = 7 * group;
        long mask = ~(~0x7fL << numBits);
        long low = 1L << numBits;

        values[i] = (RNG.nextLong() & mask) | low;
      }
    }
  }

  // Do-while encoder, as used in Beam 2.52.0. Writes the value 7 bits at a time, lowest group
  // first, setting the high bit of every byte except the last.
  static void encodeDoLoop(long v, OutputStream stream) throws IOException {
    do {
      // Take the next 7 bits, then look at what is left to decide the continuation bit.
      long payload = v & 0x7F;
      v >>>= 7;
      stream.write((byte) (payload | ((v != 0) ? 0x80 : 0)));
    } while (v != 0);
  }

  // Do-while encoder that computes the continuation bit with bitwise operations instead of a
  // comparison.
  static void encodeDoLoopTwiddle(long v, OutputStream stream) throws IOException {
    do {
      // Take the next 7 bits and drop them from the value.
      long payload = v & 0x7F;
      v >>>= 7;
      // After the unsigned shift v is non-negative, so -v is negative exactly when bits remain.
      // The arithmetic shift turns that sign into all ones or zero; the mask keeps only bit 7.
      long continuation = (-v >> 63) & 0x80;
      stream.write((byte) (payload | continuation));
    } while (v != 0);
  }

  // Encoder that tests the remaining value against the 7-bit mask before writing, so a value that
  // already fits in one byte skips the loop body entirely.
  static void encodeLoop(long v, OutputStream stream) throws IOException {
    long rest = v;
    for (; (rest & ~0x7FL) != 0; rest >>>= 7) {
      // Bits above the low 7 remain: write the low 7 bits with the continuation bit set.
      stream.write((byte) (rest | 0x80));
    }
    // What is left fits in 7 bits; the high bit stays clear and ends the encoding.
    stream.write((byte) rest);
  }

  // Unrolled form of the mask-check encoder in encodeLoop: the loop is expanded into the ten byte
  // positions a 64-bit value can need. Each of the first nine positions tests whether the
  // remaining value fits in 7 bits and, if so, writes it as the final byte and returns.
  //
  // The mask is spelled as the long literal ~0x7FL, as in encodeLoop, so the test against the
  // long value does not depend on sign extension of an int constant.
  static void encodeUnrolled(long v, OutputStream stream) throws IOException {
    // Byte 1.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 2.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 3.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 4.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 5.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 6.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 7.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 8.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 9.
    if ((v & ~0x7FL) == 0) {
      stream.write((byte) v);
      return;
    }
    stream.write((byte) (v | 0x80));
    v >>>= 7;
    // Byte 10: nine unsigned shifts by 7 have consumed 63 bits, so at most one bit remains and no
    // further test is needed.
    stream.write((byte) v);
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeDoLoop}, writing to the
   * {@link BlackholeOutput} stream.
   */
  @Benchmark
  public void encodeDoLoopBlackhole(Longs input, BlackholeOutput output) throws IOException {
    for (long l : input.values) {
      encodeDoLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeDoLoop}, writing to the
   * {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void encodeDoLoopByteString(Longs input, ByteStringOutput output) throws IOException {
    for (long l : input.values) {
      encodeDoLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeDoLoopTwiddle}, writing to
   * the {@link BlackholeOutput} stream.
   */
  @Benchmark
  public void encodeDoLoopTwiddleBlackhole(Longs input, BlackholeOutput output) throws IOException {
    for (long l : input.values) {
      encodeDoLoopTwiddle(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeDoLoopTwiddle}, writing to
   * the {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void encodeDoLoopTwiddleByteString(Longs input, ByteStringOutput output)
      throws IOException {
    for (long l : input.values) {
      encodeDoLoopTwiddle(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeLoop}, writing to the
   * {@link BlackholeOutput} stream.
   */
  @Benchmark
  public void encodeLoopBlackhole(Longs input, BlackholeOutput output) throws IOException {
    for (long l : input.values) {
      encodeLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeLoop}, writing to the
   * {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void encodeLoopByteString(Longs input, ByteStringOutput output) throws IOException {
    for (long l : input.values) {
      encodeLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeUnrolled}, writing to the
   * {@link BlackholeOutput} stream.
   */
  @Benchmark
  public void encodeUnrolledBlackhole(Longs input, BlackholeOutput output) throws IOException {
    for (long l : input.values) {
      encodeUnrolled(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Longs} input with {@link #encodeUnrolled}, writing to the
   * {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void encodeUnrolledByteString(Longs input, ByteStringOutput output) throws IOException {
    for (long l : input.values) {
      encodeUnrolled(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Bytes} input (values masked to 7 bits) with {@link
   * #encodeDoLoop}, writing to the {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void singleByteEncodeDoLoopByteString(Bytes input, ByteStringOutput output)
      throws IOException {
    for (long l : input.values) {
      encodeDoLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Bytes} input (values masked to 7 bits) with {@link
   * #encodeDoLoopTwiddle}, writing to the {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void singleByteEncodeDoLoopTwiddleByteString(Bytes input, ByteStringOutput output)
      throws IOException {
    for (long l : input.values) {
      encodeDoLoopTwiddle(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Bytes} input (values masked to 7 bits) with {@link
   * #encodeLoop}, writing to the {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void singleByteEncodeLoopByteString(Bytes input, ByteStringOutput output)
      throws IOException {
    for (long l : input.values) {
      encodeLoop(l, output.stream);
    }
  }

  /**
   * Encodes every value of the {@link Bytes} input (values masked to 7 bits) with {@link
   * #encodeUnrolled}, writing to the {@link ByteStringOutput} stream.
   */
  @Benchmark
  public void singleByteEncodeUnrolledByteString(Bytes input, ByteStringOutput output)
      throws IOException {
    for (long l : input.values) {
      encodeUnrolled(l, output.stream);
    }
  }
}
